library(lfe)
library(monomvn)
library(glmnet)
library(hdm)

# Draw one i.i.d. (row-level) bootstrap replicate and evaluate `est_fun` on it.
#
# x         : ignored; lets the function be driven by lapply(1:nboot, boot_iid, ...).
# est_fun   : estimator called with pmax, pmin, y_est, covars_est, nsims, form
#             and, only when treat_est is non-NULL, treat_est.
# pmax, pmin: matrices whose rows are observations (rows are resampled).
# y_est     : outcome vector aligned with the rows of pmax.
# covars    : optional covariates (data frame / matrix) aligned with pmax rows.
# treat_est : optional treatment vector aligned with pmax rows.
# Returns whatever est_fun returns on the resampled data.
boot_iid <- function(x, est_fun, pmax, pmin, y_est, covars=NULL, treat_est=NULL, nsims, form=form){
  n <- nrow(pmax)
  samp <- sample.int(n, n, replace = TRUE)
  # drop = FALSE keeps single-column (or single-row) inputs as matrices / data
  # frames instead of collapsing them to vectors.
  pmax_b <- as.matrix(pmax[samp, , drop = FALSE])
  pmin_b <- as.matrix(pmin[samp, , drop = FALSE])
  covars_b <- if (is.null(covars)) NULL else covars[samp, , drop = FALSE]
  if (is.null(treat_est)) {
    est_fun(pmax = pmax_b, pmin = pmin_b, y_est = y_est[samp],
            covars_est = covars_b, nsims = nsims, form = form)
  } else {
    # treat_est must be resampled with the same rows as y_est / pmax, otherwise
    # the estimator pairs outcomes with the wrong treatments.
    est_fun(pmax = pmax_b, pmin = pmin_b, y_est = y_est[samp],
            covars_est = covars_b, nsims = nsims, form = form,
            treat_est = treat_est[samp])
  }
}

# Draw one cluster bootstrap replicate and evaluate `est_fun` on it.
#
# Whole clusters are resampled with replacement; every row of a drawn cluster
# enters the replicate (a cluster drawn twice contributes its rows twice).
#
# x         : ignored; lets the function be driven by lapply(1:nboot, boot_clust, ...).
# est_fun   : estimator called with pmax, pmin, y_est, covars_est, nsims, form
#             and, only when treat_est is non-NULL, treat_est.
# clust     : cluster id per observation (aligned with pmax rows).
# pmax, pmin: matrices whose rows are observations.
# covars    : optional covariates aligned with pmax rows.
# treat_est : optional treatment vector aligned with pmax rows.
# Returns whatever est_fun returns on the resampled data.
boot_clust <- function(x, est_fun,clust, pmax, pmin, y_est, covars=NULL, treat_est=NULL, nsims, form=form){
  clusters <- unique(clust)
  n_clust <- length(clusters)
  # Index into `clusters` instead of sample(clusters, ...): with a single
  # numeric cluster id k, sample(k, ...) would draw from 1:k instead.
  clust_draw <- clusters[sample.int(n_clust, n_clust, replace = TRUE)]
  samp <- unlist(lapply(clust_draw, function(g) which(clust == g)))

  # drop = FALSE keeps single-column inputs as matrices / data frames.
  pmax_b <- as.matrix(pmax[samp, , drop = FALSE])
  pmin_b <- as.matrix(pmin[samp, , drop = FALSE])
  covars_b <- if (is.null(covars)) NULL else covars[samp, , drop = FALSE]
  if (is.null(treat_est)) {
    est_fun(pmax = pmax_b, pmin = pmin_b, y_est = y_est[samp],
            covars_est = covars_b, nsims = nsims, form = form)
  } else {
    # Resample treat_est with the same rows so it stays aligned with y_est.
    est_fun(pmax = pmax_b, pmin = pmin_b, y_est = y_est[samp],
            covars_est = covars_b, nsims = nsims, form = form,
            treat_est = treat_est[samp])
  }
}


# Per-observation contributions to a residual-on-residual slope.
#
# Rows in exactly one of the two groups (xor(in_max, in_min)) are kept. Within
# them, y_est and the in_max indicator are each residualised with felm() on
# "~1" followed by `form` (variables looked up in covars_est when supplied).
# Each kept row contributes ry * rt / sum(rt^2), so the contributions sum to
# the slope of ry on rt. Rows outside the kept set contribute 0.
#
# in_max, in_min : logical vectors, one entry per element of y_est.
# covars_est     : optional data frame supplying variables named in `form`.
# form           : text appended to the right-hand side after "~1".
# Returns a numeric vector of length(y_est).
get_resids <- function(in_max, in_min, y_est=y_est, covars_est=NULL, form){
  keep <- xor(in_max, in_min)
  if (is.null(covars_est)) {
    ry <- resid(felm(as.formula(paste0('y_est[keep]~1', form))))
    rt <- resid(felm(as.formula(paste0('in_max[keep]~1', form))))
  } else {
    # drop = FALSE keeps a one-column covariate frame a data frame; without it
    # the subset collapses to a vector and felm(data = ...) fails.
    covars_keep <- covars_est[keep, , drop = FALSE]
    ry <- resid(felm(as.formula(paste0('y_est[keep]~1', form)), data = covars_keep))
    rt <- resid(felm(as.formula(paste0('in_max[keep]~1', form)), data = covars_keep))
  }

  to_return <- rep(0, length(y_est))
  to_return[keep] <- ry * rt / sum(rt^2)
  to_return
}


# Per-observation contributions for the treatment-interaction estimator.
#
# Rows in exactly one of the two groups (xor(in_max, in_min)) are kept. Within
# the in_max group and, separately, within the in_min group, y_est and
# treat_est are residualised: on an intercept only when covars_est is NULL,
# otherwise on "~1" followed by `form` using that group's covariate rows.
# Each group's rows contribute ry * rt / sum(rt^2) (the in_min group with a
# minus sign), so each group's contributions sum to its residual slope.
#
# in_max, in_min : logical vectors, one entry per element of y_est.
# covars_est     : optional data frame supplying variables named in `form`.
# form           : text appended after "~1"; only used when covars_est is given.
# treat_est      : treatment vector aligned with y_est.
# Returns a numeric vector of length(y_est); rows outside both groups are 0.
get_resids_het <- function(in_max, in_min, y_est=y_est, covars_est=NULL, form, treat_est){
  keep <- xor(in_max, in_min)
  if (is.null(covars_est)) {
    ry_max <- resid(felm(as.formula('y_est[keep & in_max]~1')))
    rt_max <- resid(felm(as.formula('treat_est[keep & in_max]~1')))

    ry_min <- resid(felm(as.formula('y_est[keep & in_min]~1')))
    rt_min <- resid(felm(as.formula('treat_est[keep & in_min]~1')))
  } else {
    # Each group's regressions must use that group's own covariate rows.
    # drop = FALSE keeps a one-column covariate frame a data frame.
    covars_max <- covars_est[keep & in_max, , drop = FALSE]
    covars_min <- covars_est[keep & in_min, , drop = FALSE]
    # "~1" keeps the formula valid when `form` is empty or starts with "|".
    ry_max <- resid(felm(as.formula(paste0('y_est[keep & in_max]~1', form)), covars_max))
    rt_max <- resid(felm(as.formula(paste0('treat_est[keep & in_max]~1', form)), covars_max))

    ry_min <- resid(felm(as.formula(paste0('y_est[keep & in_min]~1', form)), covars_min))
    rt_min <- resid(felm(as.formula(paste0('treat_est[keep & in_min]~1', form)), covars_min))
  }

  to_return <- rep(0, length(y_est))
  to_return[keep & in_max] <- ry_max * rt_max / sum(rt_max^2)
  # keep & in_max and keep & in_min are disjoint (keep is an xor).
  to_return[keep & in_min] <- to_return[keep & in_min] - ry_min * rt_min / sum(rt_min^2)
  to_return
}


# Aggregate get_resids() contributions over the nsims simulated groupings.
#
# Column `draw` of pmax / pmin defines group membership (entries > 0) for that
# draw. Per-observation contributions are averaged across draws and then
# summed over observations, giving a single numeric estimate.
get_mcseq_est_set_ls <- function(pmax, pmin, y_est, covars_est=NULL, nsims, form){
  contrib <- matrix(NA, nrow = length(y_est), ncol = nsims)
  for (draw in 1:nsims) {
    contrib[, draw] <- get_resids(
      in_max = pmax[, draw] > 0,
      in_min = pmin[, draw] > 0,
      y_est = y_est,
      covars_est = covars_est,
      form = form
    )
  }
  # Mean over draws per observation, then total over observations.
  sum(rowMeans(contrib))
}


# Aggregate get_resids_het() contributions over the nsims simulated groupings.
#
# Column `draw` of pmax / pmin defines group membership (entries > 0) for that
# draw; treat_est is forwarded unchanged. Contributions are averaged across
# draws per observation and then summed into a single numeric estimate.
get_mcseq_est_set_het_ls <- function(pmax, pmin, y_est, covars_est=NULL, nsims, form, treat_est=treat_est){
  contrib <- matrix(NA, nrow = length(y_est), ncol = nsims)
  for (draw in 1:nsims) {
    contrib[, draw] <- get_resids_het(
      in_max = pmax[, draw] > 0,
      in_min = pmin[, draw] > 0,
      y_est = y_est,
      covars_est = covars_est,
      form = form,
      treat_est = treat_est
    )
  }
  # Mean over draws per observation, then total over observations.
  sum(rowMeans(contrib))
}



 # Split-sample estimator with least-squares targeting.
 #
 # 1. On the split sample, regress y_split on the matrix t_split (plus `form`).
 # 2. Draw nsims coefficient vectors from N(beta, vcov) for t_split's terms.
 # 3. Score t_est with each draw, residualise the scores on `form`, and flag
 #    the bottom (pmin) and top (pmax) q-quantile of each draw's scores.
 # 4. Point estimate via get_mcseq_est_set_ls(); variance from nboot bootstrap
 #    replicates (cluster bootstrap when `clust` is non-NULL).
 #
 # n_num_covars : number of trailing coefficients (after the intercept is
 #                dropped) that belong to covariates in `form`; excluded.
 # treat_est, treat_split : accepted for a common interface; unused here.
 # Returns list(est = point estimate, v = variance of bootstrap estimates).
 est_treat_mcseq_ls <- function(t_split, t_est, y_split, y_est, clust, form, covars_est=NULL, covars_split=NULL, nsims, nboot, q, treat_est, treat_split, n_num_covars=0){
   split_mod <- felm(as.formula(paste0('y_split ~ t_split', form)), covars_split)

   beta <- coef(split_mod)
   .Sigma <- vcov(split_mod)

   # fixed = TRUE matches the literal "(Intercept)" name rather than treating
   # the parentheses as a regex group.
   if (any(grepl("(Intercept)", names(beta), fixed = TRUE))) {
     beta <- beta[-1]
     .Sigma <- .Sigma[-1, -1, drop = FALSE]
   }

   # drop = FALSE keeps .Sigma a matrix even when a single coefficient remains.
   n_keep <- length(beta) - n_num_covars
   .Sigma <- .Sigma[seq_len(n_keep), seq_len(n_keep), drop = FALSE]
   beta <- beta[seq_len(n_keep)]

   keep <- !is.na(beta)
   .Sigma <- .Sigma[keep, keep, drop = FALSE]
   beta <- beta[keep]
   # mvrnorm returns a plain vector when nsims == 1; force nsims x k.
   betas <- matrix(mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                           empirical = FALSE, EISPACK = FALSE),
                   nrow = nsims)

   ypreds <- t_est[, keep, drop = FALSE] %*% t(betas)
   ypreds <- resid(felm(as.formula(paste0('ypreds~0', form)), data = covars_est))

   # Membership flags per draw; the /nsims scaling does not affect the "> 0"
   # tests applied downstream.
   pmin <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
   pmax <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

   est <- get_mcseq_est_set_ls(pmax, pmin, y_est, covars_est, nsims, form)
   if (is.null(clust)) {
     boots <- unlist(lapply(1:nboot, boot_iid, est_fun = get_mcseq_est_set_ls, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   } else {
     boots <- unlist(lapply(1:nboot, boot_clust, est_fun = get_mcseq_est_set_ls, clust = clust, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   }
   return(list(est = est, v = var(boots, na.rm = TRUE)))
 }

 
 
 # Split-sample estimator for treatment-effect heterogeneity.
 #
 # On the split sample, fits y_split ~ treat_split * t_split (plus `form` when
 # covariates are given) and keeps the last ncol(t_split) coefficients. It
 # assumes these are the treat_split:t_split interactions, because R orders
 # higher-order terms after main effects. It draws nsims coefficient vectors,
 # scores t_est, and flags the bottom/top q-quantile per draw. The estimate
 # comes from get_mcseq_est_set_het_ls(); the variance comes from nboot
 # bootstrap replicates.
 #
 # Returns list(est = point estimate, v = variance of bootstrap estimates).
 est_het_mcseq_ls <- function(t_split, t_est, y_split, y_est, clust, form, covars_est=NULL, covars_split=NULL, nsims, nboot, q, treat_split, treat_est){
   if (!is.null(covars_split)) {
     split_mod <- felm(as.formula(paste0('y_split ~ treat_split*t_split', form)), covars_split)
   } else {
     split_mod <- felm(as.formula('y_split ~ treat_split*t_split'))
   }
   beta <- coef(split_mod)
   int_idx <- (length(beta) - ncol(t_split) + 1):length(beta)
   # drop = FALSE keeps .Sigma a matrix when t_split has a single column.
   .Sigma <- vcov(split_mod)[int_idx, int_idx, drop = FALSE]
   beta <- beta[int_idx]

   keep <- !is.na(beta)
   .Sigma <- .Sigma[keep, keep, drop = FALSE]
   beta <- beta[keep]
   # mvrnorm returns a plain vector when nsims == 1; force nsims x k.
   betas <- matrix(mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                           empirical = FALSE, EISPACK = FALSE),
                   nrow = nsims)

   # Predictions are used as-is here (not residualised on `form`).
   ypreds <- t_est[, keep, drop = FALSE] %*% t(betas)

   pmin <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
   pmax <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

   est <- get_mcseq_est_set_het_ls(pmax, pmin, y_est, covars_est, nsims, form, treat_est = treat_est)
   if (is.null(clust)) {
     boots <- unlist(lapply(1:nboot, boot_iid, est_fun = get_mcseq_est_set_het_ls, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form,
                            treat_est = treat_est))
   } else {
     boots <- unlist(lapply(1:nboot, boot_clust, est_fun = get_mcseq_est_set_het_ls, clust = clust, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form,
                            treat_est = treat_est))
   }
   return(list(est = est, v = var(boots)))
 }
 
 
 library(hierNet)
 # Post-lasso variant of est_treat_mcseq_ls().
 #
 # Columns of t_split are first selected with cv.glmnet (lambda.min) after
 # partialling `form` out of y_split and t_split. The least-squares procedure
 # is then run on the selected columns only. treat_est / treat_split are
 # accepted for a common interface and are unused.
 # Returns list(est = point estimate, v = variance of bootstrap estimates).
 est_treat_mcseq_post_lasso <- function(t_split, t_est, y_split, y_est, clust, form, covars_est, covars_split, nsims, nboot, q, treat_est=NULL, treat_split=NULL){
   resid_y <- resid(felm(as.formula(paste0('y_split ~ 0', form)), covars_split))
   resid_X <- resid(felm(as.formula(paste0('t_split ~ 0', form)), covars_split))
   # NOTE(review): registers a 5-worker parallel backend as a side effect of
   # every call; consider moving this to the calling script.
   library(doParallel)
   registerDoParallel(5)

   sel_mod <- cv.glmnet(x = resid_X, y = resid_y, nfolds = 5, parallel = TRUE)
   # [-1] drops glmnet's intercept so the flags align with t_split's columns.
   selected <- as.numeric(coef(sel_mod, s = 'lambda.min'))[-1] != 0
   if (!any(selected)) {
     stop("lasso selected no columns of t_split; cannot fit the post-lasso model", call. = FALSE)
   }
   t_split <- t_split[, selected, drop = FALSE]
   t_est <- t_est[, selected, drop = FALSE]
   split_mod <- felm(as.formula(paste0('y_split ~ t_split', form)), covars_split)

   beta <- coef(split_mod)
   .Sigma <- vcov(split_mod)
   if (any(grepl("(Intercept)", names(beta), fixed = TRUE))) {
     beta <- beta[-1]
     .Sigma <- .Sigma[-1, -1, drop = FALSE]
   }
   # t_split is the first term of the formula, so its coefficients come first.
   # Any coefficients after them belong to covariates in `form` and must not
   # be used to score t_est.
   n_t <- ncol(t_split)
   beta <- beta[seq_len(n_t)]
   .Sigma <- .Sigma[seq_len(n_t), seq_len(n_t), drop = FALSE]

   keep <- !is.na(beta)
   .Sigma <- .Sigma[keep, keep, drop = FALSE]
   beta <- beta[keep]
   # mvrnorm returns a plain vector when nsims == 1; force nsims x k.
   betas <- matrix(mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                           empirical = FALSE, EISPACK = FALSE),
                   nrow = nsims)

   ypreds <- t_est[, keep, drop = FALSE] %*% t(betas)
   ypreds <- resid(felm(as.formula(paste0('ypreds~0', form)), data = covars_est))

   pmin <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
   pmax <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

   est <- get_mcseq_est_set_ls(pmax, pmin, y_est, covars_est, nsims, form)
   # Both bootstrap schemes use the same covariates as the point estimate.
   if (is.null(clust)) {
     boots <- unlist(lapply(1:nboot, boot_iid, est_fun = get_mcseq_est_set_ls, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   } else {
     boots <- unlist(lapply(1:nboot, boot_clust, est_fun = get_mcseq_est_set_ls, clust = clust, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   }
   return(list(est = est, v = var(boots)))
 }

 
 library(hierNet)
 # Two-stage lasso variant of est_treat_mcseq_ls().
 #
 # Stage 1 selects main-effect columns of t_split with cv.glmnet after
 # partialling out `form`; if none are selected it falls back to
 # est_treat_mcseq_ls() on all columns. Stage 2 builds pairwise interactions
 # of the selected columns (model.matrix(~.*.)), residualises them and y on
 # t_split plus `form`, and runs a second cv.glmnet. The selected mains and
 # the selected interactions then feed the least-squares procedure.
 # Returns list(est = point estimate, v = variance of bootstrap estimates).
 est_treat_mcseq_hier_lasso2 <- function(t_split, t_est, y_split, y_est, clust, form, covars_est, covars_split, nsims, nboot, q){
    resid_y <- resid(felm(as.formula(paste0('y_split ~ 0', form)), covars_split))
    resid_X <- resid(felm(as.formula(paste0('t_split ~ 0', form)), covars_split))
    # NOTE(review): registers a 5-worker parallel backend on every call.
    library(doParallel)
    registerDoParallel(5)

    # Stage 1: main effects.
    sel_mod1 <- cv.glmnet(x = resid_X, y = resid_y, nfolds = 5, parallel = TRUE)
    if (sel_mod1$nzero[sel_mod1$lambda == sel_mod1$lambda.min] == 0) {
      return(est_treat_mcseq_ls(t_split, t_est, y_split, y_est, clust, form, covars_est, covars_split, nsims, nboot, q))
    }
    # [-1] drops glmnet's intercept so flags align with t_split's columns.
    sel1 <- as.numeric(coef(sel_mod1, s = 'lambda.min'))[-1] != 0

    # Stage 2: pairwise interactions of the selected mains.
    interacted <- model.matrix(~.*., data = as.data.frame(resid_X[, sel1, drop = FALSE]))
    interacted <- resid(felm(as.formula(paste0('interacted~t_split', form)), data = covars_split))
    resid_y <- resid(felm(as.formula(paste0('y_split~t_split', form)), data = covars_split))
    sel_mod2 <- cv.glmnet(x = interacted, y = resid_y, nfolds = 5, parallel = TRUE)

    t_split_main <- t_split[, sel1, drop = FALSE]
    t_est_main <- t_est[, sel1, drop = FALSE]
    if (sel_mod2$nzero[sel_mod2$lambda == sel_mod2$lambda.min] > 0) {
      sel2 <- as.numeric(coef(sel_mod2, s = 'lambda.min'))[-1] != 0
      # Same model.matrix layout as `interacted`, so sel2 indexes it by position.
      int_split <- model.matrix(~.*., data = as.data.frame(t_split_main))
      int_est <- model.matrix(~.*., data = as.data.frame(t_est_main))
      t_split <- cbind(t_split_main, int_split[, sel2, drop = FALSE])
      t_est <- cbind(t_est_main, int_est[, sel2, drop = FALSE])
    } else {
      t_split <- t_split_main
      t_est <- t_est_main
    }

    split_mod <- felm(as.formula(paste0('y_split ~ t_split', form)), covars_split)

    beta <- coef(split_mod)
    .Sigma <- vcov(split_mod)
    # Drop felm's intercept (when present), then keep only t_split's
    # coefficients, which come first; covariates from `form` follow them.
    if (any(grepl("(Intercept)", names(beta), fixed = TRUE))) {
      beta <- beta[-1]
      .Sigma <- .Sigma[-1, -1, drop = FALSE]
    }
    n_t <- ncol(t_split)
    beta <- beta[seq_len(n_t)]
    .Sigma <- .Sigma[seq_len(n_t), seq_len(n_t), drop = FALSE]

    keep <- !is.na(beta)
    .Sigma <- .Sigma[keep, keep, drop = FALSE]
    beta <- beta[keep]
    # mvrnorm returns a plain vector when nsims == 1; force nsims x k.
    betas <- matrix(mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                            empirical = FALSE, EISPACK = FALSE),
                    nrow = nsims)

    # t_est is always a matrix here (drop = FALSE above), so every draw is
    # scored with its own coefficients and residualised on `form`.
    ypreds <- t_est[, keep, drop = FALSE] %*% t(betas)
    ypreds <- resid(felm(as.formula(paste0('ypreds~0', form)), data = covars_est))

    pmin <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
    pmax <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

    est <- get_mcseq_est_set_ls(pmax, pmin, y_est, covars_est, nsims, form)
    if (is.null(clust)) {
      boots <- unlist(lapply(1:nboot, boot_iid, est_fun = get_mcseq_est_set_ls, pmax = pmax,
                             pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
    } else {
      boots <- unlist(lapply(1:nboot, boot_clust, est_fun = get_mcseq_est_set_ls, clust = clust, pmax = pmax,
                             pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
    }
    return(list(est = est, v = var(boots)))
 }
 
 
 library(hierNet)
 # Post-hierNet variant of est_treat_mcseq_ls().
 #
 # Fits a hierNet lasso path (delta = 0) on t_split and y_split, both
 # residualised on `form`, and picks lambda by hierNet.cv. At that lambda it
 # keeps the non-zero main effects, the non-zero squared terms (diagonal of
 # the symmetrised interaction matrix) and the non-zero pairwise
 # interactions. The least-squares procedure then runs on that design.
 # Returns list(est = point estimate, v = variance of bootstrap estimates).
 est_treat_mcseq_post_hier_lasso <- function(t_split, t_est, y_split, y_est, clust, form, covars_est, covars_split, nsims, nboot, q){
   resid_y <- resid(felm(as.formula(paste0('y_split ~ 0', form)), covars_split))
   resid_X <- resid(felm(as.formula(paste0('t_split ~ 0', form)), covars_split))

   fit <- hierNet.path(x = resid_X, y = c(resid_y), delta = 0)
   fitcv <- hierNet.cv(fit, resid_X, resid_y, nfolds = 5)
   at_lam <- fit$lamlist == fitcv$lamhat

   th <- fit$th[, , at_lam]
   interact <- (th + t(th)) / 2
   keep_interact <- interact != 0

   t_split <- as.data.frame(t_split)
   t_est <- as.data.frame(t_est)

   var_names <- colnames(t_split)
   # Backquote names so non-syntactic column names survive formula parsing.
   quoted <- paste0('`', var_names, '`')
   mains <- quoted[fit$bp[, at_lam] - fit$bn[, at_lam] != 0]
   # Only pairs (i < j) whose interaction coefficient is non-zero.
   use_pair <- upper.tri(keep_interact) & keep_interact
   interacts <- paste0(quoted[row(keep_interact)[use_pair]], ':',
                       quoted[col(keep_interact)[use_pair]])
   # I() is required: inside a formula `x^2` is term-crossing notation and
   # reduces to x itself rather than squaring it.
   quads <- paste0('I(', quoted[diag(interact) != 0], '^2)')

   design_terms <- c(mains, quads, interacts)
   if (length(design_terms) == 0) {
     stop("hierNet selected no terms; cannot fit the post-selection model", call. = FALSE)
   }
   design <- as.formula(paste0('~', paste0(design_terms, collapse = '+')))
   # [, -1] drops model.matrix's intercept column; felm / `form` provide their
   # own intercept or fixed effects.
   t_split <- model.matrix(design, data = t_split)[, -1, drop = FALSE]
   t_est <- model.matrix(design, data = t_est)[, -1, drop = FALSE]

   split_mod <- felm(as.formula(paste0('y_split ~ t_split', form)), covars_split)

   beta <- coef(split_mod)
   .Sigma <- vcov(split_mod)
   if (any(grepl("(Intercept)", names(beta), fixed = TRUE))) {
     beta <- beta[-1]
     .Sigma <- .Sigma[-1, -1, drop = FALSE]
   }
   # t_split's coefficients come first; covariates from `form` follow them.
   n_t <- ncol(t_split)
   beta <- beta[seq_len(n_t)]
   .Sigma <- .Sigma[seq_len(n_t), seq_len(n_t), drop = FALSE]

   keep <- !is.na(beta)
   .Sigma <- .Sigma[keep, keep, drop = FALSE]
   beta <- beta[keep]
   # mvrnorm returns a plain vector when nsims == 1; force nsims x k.
   betas <- matrix(mvrnorm(n = nsims, mu = beta, Sigma = .Sigma, tol = 1e-6,
                           empirical = FALSE, EISPACK = FALSE),
                   nrow = nsims)

   ypreds <- t_est[, keep, drop = FALSE] %*% t(betas)
   ypreds <- resid(felm(as.formula(paste0('ypreds~0', form)), data = covars_est))

   pmin <- apply(ypreds, 2, function(x){x <= quantile(x, q)}) / nsims
   pmax <- apply(ypreds, 2, function(x){x >= quantile(x, 1 - q)}) / nsims

   est <- get_mcseq_est_set_ls(pmax, pmin, y_est, covars_est, nsims, form)
   if (is.null(clust)) {
     boots <- unlist(lapply(1:nboot, boot_iid, est_fun = get_mcseq_est_set_ls, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   } else {
     boots <- unlist(lapply(1:nboot, boot_clust, est_fun = get_mcseq_est_set_ls, clust = clust, pmax = pmax,
                            pmin = pmin, y_est = y_est, covars = covars_est, nsims = nsims, form = form))
   }
   return(list(est = est, v = var(boots)))
 }
 
 
# Two-fold cross-fitting wrapper around an estimator `est_fun`.
#
# Observations are split into two folds (by cluster when `clust` is
# non-NULL). For each fold, `est_fun` is fit with the other fold as the split
# sample and this fold as the estimation sample. The two estimates are
# averaged, and se = sqrt(mean of the two variances).
#
# Only arguments that est_fun declares (or all of them, if it has `...`) are
# passed. Estimators without treat_est / treat_split / n_num_covars
# parameters can therefore be used.
# Returns list(est = mean estimate, se = standard error).
est_mcseq_2_samp <- function(treat_types, y, covars=NULL, nboot=100, nsims=10, q, clust, est_fun, form=NULL, treatment=NULL, n_num_covars=0){
  if (is.null(clust)) {
    folds <- sample(1:2, nrow(treat_types), replace = TRUE)
  } else {
    clusters <- unique(clust)
    # Index-based draw: sample(clusters, ...) would draw from 1:k if there
    # were a single numeric cluster id k.
    fold1 <- clusters[sample.int(length(clusters), length(clusters) %/% 2)]
    folds <- as.numeric(clust %in% fold1) + 1
  }

  accepted <- names(formals(est_fun))
  ests <- rep(NA, 2)
  vars <- rep(NA, 2)
  for (fold in 1:2) {
    in_fold <- folds == fold
    # drop = FALSE keeps single-column inputs as matrices / data frames.
    args <- list(
      t_split = treat_types[!in_fold, , drop = FALSE],
      t_est = treat_types[in_fold, , drop = FALSE],
      y_split = y[!in_fold],
      y_est = y[in_fold],
      clust = clust[in_fold],
      form = form,
      covars_est = if (is.null(covars)) NULL else covars[in_fold, , drop = FALSE],
      covars_split = if (is.null(covars)) NULL else covars[!in_fold, , drop = FALSE],
      nsims = nsims,
      nboot = nboot,
      q = q,
      treat_est = treatment[in_fold],
      treat_split = treatment[!in_fold],
      n_num_covars = n_num_covars
    )
    if (!("..." %in% accepted)) {
      args <- args[names(args) %in% accepted]
    }
    mod <- do.call(est_fun, args)

    ests[fold] <- mod$est
    vars[fold] <- mod$v
  }
  return(list(est = mean(ests), se = sqrt(mean(vars))))
}

# Repeat est_mcseq_2_samp() `wraps` times (each with a fresh random fold
# split) and pool the results.
#
# nboot and nsims are forwarded to every repetition. treatment (optional,
# default NULL) is forwarded for heterogeneity estimators.
# Returns list(est = mean of the estimates,
#              se = sqrt(mean of the squared standard errors)).
est_mcseq_2_samp_wrapped <- function(treat_types, y, covars, nboot=100, nsims=10, q, clust, est_fun, form, wraps,n_num_covars=0, treatment=NULL){
  fits <- lapply(seq_len(wraps), function(i) {
    est_mcseq_2_samp(treat_types = treat_types, y = y, covars = covars,
                     nboot = nboot, nsims = nsims, q = q, clust = clust,
                     est_fun = est_fun, form = form, treatment = treatment,
                     n_num_covars = n_num_covars)
  })
  ests <- vapply(fits, function(m) m[[1]], numeric(1))
  vars <- vapply(fits, function(m) m[[2]]^2, numeric(1))
  gc()
  list(est = mean(ests), se = sqrt(mean(vars)))
}
